import torch, argparse, os, pickle
from posteriors import recover_docs
import scipy.io
import numpy as np
from torch.utils.data import DataLoader
from models_datasets import AttnCTMDataset, generate_attn_batch, AttentionModel
import math
import glob

## Experiment settings: keep these identical to the values used when the
## data and models were produced.
N_test = 200
num_words = 60

# Batch size used when running the model over the test documents.
BATCH_SIZE = 128

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Confidence level (in percent) of the reported intervals, and the lookup
# table mapping each supported level to its two-sided normal critical value.
confidence = 95
confidence_val = {
    90: 1.645,
    95: 1.96,
    99: 2.576,
}
z = confidence_val[confidence]

def top_k_ovelap(y_true, y_pred, k, z_score=None):
    """Return the mean top-k overlap rate and its confidence half-width.

    For every row, the indices of the ``k`` largest entries of ``y_true`` are
    compared with the indices of the ``k`` largest entries of ``y_pred``; the
    row's overlap rate is the size of the intersection divided by ``k``.

    Args:
        y_true: 2-D array-like of reference scores, one row per sample.
        y_pred: 2-D array-like of predicted scores with the same shape.
        k: Number of top entries compared per row; must satisfy
            ``1 <= k <= number of columns``.
        z_score: Critical value that multiplies the standard error. When
            omitted, the module-level ``z`` is used.

    Returns:
        A ``(mean, half_width)`` tuple, both rounded to 4 decimals, where
        ``half_width = z_score * std / sqrt(number of rows)``.

    Raises:
        ValueError: If the inputs differ in shape, are not 2-D, or ``k`` is
            outside ``[1, number of columns]``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if y_true.shape != y_pred.shape:
        raise ValueError(
            'y_true and y_pred must have the same shape, got %s and %s'
            % (y_true.shape, y_pred.shape))
    if y_true.ndim != 2:
        raise ValueError('expected 2-D inputs, got %d-D' % y_true.ndim)

    n_rows, n_cols = y_true.shape
    # Fail loudly: a sentinel return value cannot be unpacked or formatted as
    # a (mean, half_width) pair by the callers, and k == 0 would divide by 0.
    if not 1 <= k <= n_cols:
        raise ValueError('k must be in [1, %d], got %r' % (n_cols, k))

    if z_score is None:
        z_score = z

    # argsort is ascending, so the last k columns hold the top-k indices.
    top_true = np.argsort(y_true, axis=1)[:, -k:]
    top_pred = np.argsort(y_pred, axis=1)[:, -k:]
    overlap_rate = np.array([
        len(set(pred_idx) & set(true_idx)) / k
        for pred_idx, true_idx in zip(top_pred.tolist(), top_true.tolist())
    ])

    mean = overlap_rate.mean()
    std = overlap_rate.std()
    return round(mean, 4), round(z_score * std / math.sqrt(n_rows), 4)

if __name__ == '__main__':
    # For every saved checkpoint: load the matching data set, run the model
    # over the test documents, recover a topic posterior estimate from the
    # model output, save both arrays and print the evaluation metrics.

    # basename() rather than splitting on '/' so the listing does not depend
    # on the platform's path separator.
    filename_list = sorted(
        os.path.basename(path)
        for path in glob.glob(os.path.join('savedmodels', '*.pt'))
    )

    # (label, k) pairs for the top-k overlap lines of the report.
    overlap_reports = (
        ('MAP accuracy', 1),
        ('top 2 overlap', 2),
        ('top 4 overlap', 4),
        ('top 6 overlap', 6),
    )

    for weights_file in filename_list:
        # The character at index 5 of the file name is read as the
        # single-digit alpha value. Skip names that are too short to have
        # that character (indexing them would raise IndexError) or that do
        # not carry a digit there.
        if len(weights_file) <= 5 or weights_file[5] not in '0123456789':
            continue
        alpha = int(weights_file[5])
        data_path = f'data/alpha{alpha}.0_{num_words}wordsdoc'

        # NOTE: pickle.load can execute arbitrary code; only point this
        # script at data files you produced yourself.
        # A is the matrix whose transposed pseudo-inverse is applied to the
        # model output below.
        with open(os.path.join(data_path, "topics.pkl"), 'rb') as fname:
            A = pickle.load(fname)

        # L is the reference the recovered posterior is compared against
        # (total-variation distance).
        with open(os.path.join(data_path, f'True_Posterior_{N_test}.pkl'), 'rb') as f:
            L = pickle.load(f)
        # prior is the reference used for the top-k overlap metrics.
        with open(os.path.join(data_path, 'test_topics_dist.pkl'), 'rb') as f:
            prior = pickle.load(f)
        prior = np.array(prior)

        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only machine, matching the CPU fallback of `device`.
        # NOTE(review): the checkpoint is used as a full model object; on
        # PyTorch versions where torch.load defaults to weights_only=True
        # this call may need weights_only=False -- confirm against the
        # installed version.
        model = torch.load(os.path.join('savedmodels', weights_file),
                           map_location=device)

        # get test documents
        token_file = os.path.join(data_path, f'bow_test_tokens_{N_test}')
        count_file = os.path.join(data_path, f'bow_test_counts_{N_test}')

        tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
        counts = scipy.io.loadmat(count_file)['counts'].squeeze()

        test_documents = recover_docs(tokens, counts)

        # get predictions (N, V); the labels are dummies, only the text is used
        test_dataset = AttnCTMDataset(test_documents, [0] * len(test_documents))
        test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                                     num_workers=12, pin_memory=True,
                                     shuffle=False,
                                     collate_fn=generate_attn_batch)

        model.to(device)
        model.eval()
        with torch.no_grad():
            all_preds = []
            print("Making predictions")
            for text, _ in test_dataloader:
                predictions = model.get_word_probability(text.to(device)).squeeze(1).float()
                all_preds.append(predictions.cpu())

            print("Tensorizing result")
            output = torch.cat(all_preds)
            output = output.view(len(test_dataset), -1)

        # Convert explicitly so every NumPy routine below operates on an
        # ndarray instead of relying on implicit tensor conversion.
        output = output.numpy()

        # Estimate E(posterior): apply pinv(A^T) to each document's output,
        # zero out negative entries and renormalise each row to sum to 1.
        Eta = np.transpose(np.dot(np.linalg.pinv(np.transpose(A)),
                                  np.transpose(output)))
        Eta = Eta * (Eta > 0)
        row_sums = Eta.sum(axis=1)
        Eta = Eta / row_sums[:, np.newaxis]

        # Per-document total-variation distance to the reference posterior.
        TV_data = np.sum(np.abs(Eta - L) / 2, axis=1)
        mean_TV = TV_data.mean()
        std_TV = TV_data.std()

        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('recovered', exist_ok=True)
        np.save(os.path.join('recovered', weights_file + '_doc_representation.npy'), output)
        np.save(os.path.join('recovered', weights_file + '_topic_posterior.npy'), Eta)

        print('='*40)
        print('model weight file:', weights_file)
        print('Confidence level: %d'%confidence)
        print('TV confidence interval: %.4f +/- %.4f'%(mean_TV, z*std_TV/math.sqrt(len(test_documents))))
        n_topics = Eta.shape[1]
        for label, k in overlap_reports:
            # A k larger than the number of topics cannot be evaluated;
            # skip it instead of crashing while formatting the result.
            if k > n_topics:
                print('%s confidence interval: skipped (k=%d exceeds %d topics)'
                      % (label, k, n_topics))
                continue
            overlap_mean, overlap_half_width = top_k_ovelap(prior, Eta, k)
            print('%s confidence interval: %.4f +/- %.4f'
                  % (label, overlap_mean, overlap_half_width))
        print('='*40)
